package com.flow.framework.web.filter;

import com.flow.framework.base.constant.FrameworkBaseConstant;
import com.flow.framework.base.properties.FrameworkBaseConfigProperties;
import com.flow.framework.base.properties.component.RequestConfigProperties;
import com.flow.framework.base.properties.component.ResponseConfigProperties;
import com.flow.framework.base.service.access.log.IAccessLogService;
import com.flow.framework.common.constant.FrameworkCommonConstant;
import com.flow.framework.common.error.SystemErrorCode;
import com.flow.framework.common.exception.CheckedException;
import com.flow.framework.common.json.JsonObject;
import com.flow.framework.common.util.collection.CollectionUtil;
import com.flow.framework.common.util.io.IoUtil;
import com.flow.framework.common.util.random.RandomUtil;
import com.flow.framework.common.util.verify.VerifyUtil;
import com.flow.framework.core.constant.FrameworkCoreConstant;
import com.flow.framework.core.holder.OptLogI18nContextHolder;
import com.flow.framework.core.holder.SecurityContextHolder;
import com.flow.framework.core.holder.SystemVersionContextHolder;
import com.flow.framework.core.response.Response;
import com.flow.framework.core.util.HttpContentTypeUtil;
import com.flow.framework.facade.access.log.constant.AccessLogFacadeConstant;
import com.flow.framework.facade.access.log.opt.annotation.OptLog;
import com.flow.framework.web.context.RequestResponseContext;
import com.flow.framework.web.helper.AccessLogHelper;
import com.flow.framework.web.holder.RequestResponseContextHolder;
import com.flow.framework.web.util.HttpServletUtil;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.MDC;
import org.springframework.cloud.context.config.annotation.RefreshScope;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.util.ContentCachingRequestWrapper;
import org.springframework.web.util.ContentCachingResponseWrapper;
import org.springframework.web.util.WebUtils;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;

/**
 * Global web filter.
 * <p>
 * For every request it resets the thread-bound holders (MDC, security context, system version,
 * operation-log i18n context), binds a trace id to the MDC and, when access logging or an
 * {@link OptLog} applies to the request, wraps request and response in content-caching wrappers
 * so that their bodies can be captured and handed to {@link AccessLogHelper}. Exceptions thrown
 * by the rest of the chain are converted into a framework {@link Response}.
 *
 * @author luoguopiao
 * @version 0.0.1
 * @date 2022/12/25
 */
@Order(Ordered.HIGHEST_PRECEDENCE)
@WebFilter(urlPatterns = "/*")
@Slf4j
@RefreshScope
@RequiredArgsConstructor
public class RequestResponseBodyFilter extends OncePerRequestFilter {

    private final IAccessLogService accessLogService;

    private final AccessLogHelper accessLogHelper;

    private final FrameworkBaseConfigProperties frameworkBaseConfigProperties;

    /**
     * Returns {@code false} so that this filter is also invoked on async dispatches
     * (the {@link OncePerRequestFilter} default is {@code true}, i.e. skip them).
     *
     * @return always {@code false}
     */
    @Override
    protected boolean shouldNotFilterAsyncDispatch() {
        return false;
    }

    /**
     * Returns {@code false} so that this filter is also invoked on error dispatches
     * (the {@link OncePerRequestFilter} default is {@code true}, i.e. skip them).
     *
     * @return always {@code false}
     */
    @Override
    protected boolean shouldNotFilterErrorDispatch() {
        return false;
    }

    /**
     * Clears all thread-bound context, binds the trace id (taken from the request header, or
     * freshly generated when the header is absent) to the MDC, delegates to
     * {@link #customizationDoFilterInternal} and finally clears the thread-bound context again.
     *
     * @param request     current request
     * @param response    current response
     * @param filterChain remaining filter chain
     * @throws ServletException propagated from the filter chain
     * @throws IOException      propagated from the filter chain or from writing the error response
     * @see #afterRequest
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        // Clear context left over on this (possibly pooled) thread.
        MDC.clear();
        SecurityContextHolder.clearAll();
        SystemVersionContextHolder.clear();
        OptLogI18nContextHolder.clear();
        try {
            String traceId = request.getHeader(FrameworkCommonConstant.GLOBAL_LOG_TRACE_KEY);
            if (VerifyUtil.isEmpty(traceId)) {
                traceId = RandomUtil.random20LenId();
            }
            MDC.put(FrameworkCommonConstant.GLOBAL_LOG_TRACE_KEY, traceId);
            customizationDoFilterInternal(request, response, filterChain);
        } finally {
            // Clear context so nothing leaks into the next request handled by this thread.
            MDC.clear();
            SecurityContextHolder.clearAll();
            SystemVersionContextHolder.clear();
            OptLogI18nContextHolder.clear();
            RequestResponseContextHolder.clear();
        }
    }

    /**
     * Decides whether the request has to be recorded, wraps request/response when their bodies
     * must be captured, runs the rest of the chain and records the access log afterwards.
     *
     * @param request     current request
     * @param response    current response
     * @param filterChain remaining filter chain
     * @throws ServletException propagated from the filter chain when no recording applies
     * @throws IOException      propagated from the filter chain when no recording applies, or
     *                          raised while writing the framework error response
     */
    private void customizationDoFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        long startTime = System.currentTimeMillis();
        boolean isFirstRequest = !isAsyncDispatch(request);
        String method = request.getMethod();
        String uri = request.getRequestURI();
        RequestConfigProperties requestConfigProperties = frameworkBaseConfigProperties.getRequest();
        ResponseConfigProperties responseConfigProperties = frameworkBaseConfigProperties.getResponse();

        // Work out the recording flags; the URI check is evaluated at most once.
        boolean persistenceConfigured = requestConfigProperties.isEnableRecordPersistence();
        boolean localConfigured = requestConfigProperties.isEnableRecordLocal();
        boolean recordableUri = (persistenceConfigured || localConfigured) && isRequireRecordLogUri(uri, method);
        boolean enableRecordPersistence = persistenceConfigured && recordableUri;
        boolean enableRecordLocal = localConfigured && recordableUri;
        OptLog optLog = accessLogService.matchOptLogAnnotation(method, uri);

        if (!enableRecordLocal && null == optLog && !enableRecordPersistence) {
            // NOTE(review): preprocess() is not executed on this path, so the user context and
            // system version headers are not propagated for requests that are excluded from
            // logging - confirm that this is intended.
            filterChain.doFilter(request, response);
            return;
        }

        int requestMaxLen = Math.max(requestConfigProperties.getRecordLocalMaxLength(), requestConfigProperties.getRecordPersistenceMaxLength());
        int responseMaxLen = Math.max(responseConfigProperties.getRecordLocalMaxLength(), responseConfigProperties.getRecordPersistenceMaxLength());
        String requestContentType = request.getContentType();
        // NOTE(review): as written, an empty content type triggers wrapping even on async
        // dispatches and for already wrapped requests, because "||" binds last. The grouping is
        // kept exactly as it was; verify whether
        // "first && !wrapped && (textBody || empty)" was intended.
        boolean wrapperRequestResponse = (isFirstRequest && !(request instanceof ContentCachingRequestWrapper)
                && HttpContentTypeUtil.isTextBody(requestContentType)) || VerifyUtil.isEmpty(requestContentType);

        HttpServletRequest requestToUse = request;
        HttpServletResponse responseToUse = response;
        boolean isRecordAccessLog = false;
        if (wrapperRequestResponse) {
            requestToUse = new ContentCachingRequestWrapper(request, requestMaxLen);
            responseToUse = new ContentCachingResponseWrapper(response);
            isRecordAccessLog = true;
        }
        RequestResponseContext requestResponseContext = RequestResponseContextHolder
                .initContext(startTime, requestToUse, responseToUse, requestConfigProperties, responseConfigProperties);
        requestResponseContext.setOptLog(optLog);
        requestResponseContext.setEnableRecordPersistence(enableRecordPersistence);
        requestResponseContext.setEnableRecordLocal(enableRecordLocal);
        preprocess(requestResponseContext);

        try {
            filterChain.doFilter(requestToUse, responseToUse);
        } catch (Exception e) {
            log.error("request occurred exception.", e);
            writeErrorResponse(requestToUse, responseToUse, e);
        } finally {
            if (isRecordAccessLog) {
                afterRequest(requestMaxLen, responseMaxLen);
            }
        }
    }

    /**
     * Turns an exception thrown by the filter chain into a framework error response.
     * <p>
     * The status is set to 500. For an RPC request that does not yet carry the framework
     * response header, the URL-encoded framework response is put into that header; in every
     * other case the framework response is written as the JSON body.
     *
     * @param request  request that was passed down the chain
     * @param response response that was passed down the chain
     * @param cause    exception thrown by the chain
     * @throws IOException if the error response cannot be encoded or written
     */
    private void writeErrorResponse(HttpServletRequest request, HttpServletResponse response, Exception cause) throws IOException {
        if (response.isCommitted()) {
            // Status and headers can no longer be changed and anything written now would be
            // appended to output the client has already started to receive.
            log.warn("response already committed, skip writing framework error response.");
            return;
        }

        // Build the framework response.
        Response<String> frameworkResponse = null;
        if (cause instanceof CheckedException) {
            frameworkResponse = Response.failed((CheckedException) cause);
        }
        if (null == frameworkResponse) {
            frameworkResponse = Response.failed(SystemErrorCode.API_SERVER_ERROR);
        }

        // Mark the response as a server error first.
        response.setStatus(HttpStatus.INTERNAL_SERVER_ERROR.value());

        boolean rpcRequest = HttpServletUtil.isRpcRequest(request);
        String frameworkResponseHeader = response.getHeader(FrameworkCoreConstant.FRAMEWORK_RESPONSE_KEY);
        if (rpcRequest && null == frameworkResponseHeader) {
            // RPC callers read the framework response from the header; the body is not touched
            // because the status code already signals the server error.
            response.setHeader(FrameworkCoreConstant.FRAMEWORK_RESPONSE_KEY,
                    URLEncoder.encode(JsonObject.toString(frameworkResponse), StandardCharsets.UTF_8.name()));
            return;
        }

        // Otherwise the framework response becomes the response body. Drop whatever the handler
        // buffered before it failed so the client does not get partial output followed by JSON.
        String responseStr = JsonObject.fromObject(frameworkResponse).toString();
        response.resetBuffer();
        response.setCharacterEncoding(StandardCharsets.UTF_8.name());
        response.setContentType(MediaType.APPLICATION_JSON_VALUE);
        try {
            response.getOutputStream().write(responseStr.getBytes(StandardCharsets.UTF_8));
        } catch (IllegalStateException ex) {
            // getWriter() has already been called on this response; the servlet API then
            // forbids getOutputStream(), so fall back to the writer.
            response.getWriter().write(responseStr);
        }
    }

    /**
     * Captures the (length-limited) request and response bodies into the current
     * {@link RequestResponseContext}, copies the cached response body to the real response and
     * hands the context to {@link AccessLogHelper#asyncRecordAccessLog}.
     * <p>
     * When the response has no content type yet, the context is stored as a request attribute
     * and nothing is recorded.
     *
     * @param requestMaxLen     maximum number of request body bytes to keep
     * @param responseMaxLength maximum number of response body bytes to keep
     * @throws CheckedException if the cached response body cannot be copied to the response
     */
    private void afterRequest(int requestMaxLen, int responseMaxLength) {
        RequestResponseContext context = RequestResponseContextHolder.getContext();
        HttpServletRequest originalRequest = context.getOriginalRequest();
        HttpServletResponse originalResponse = context.getOriginalResponse();
        ContentCachingResponseWrapper responseWrapper =
                WebUtils.getNativeResponse(originalResponse, ContentCachingResponseWrapper.class);

        String responseContentType = originalResponse.getContentType();
        if (VerifyUtil.isEmpty(responseContentType)) {
            originalRequest.setAttribute(FrameworkBaseConstant.REQUEST_RESPONSE_CONTEXT_KEY, context);
            log.info("response isn't ready.");
            // A handler may have written a body without declaring a content type. Unless async
            // handling is still running, flush it; otherwise it would stay in the cache and
            // never reach the client. Copying an empty cache is a no-op.
            if (responseWrapper != null && !isAsyncStarted(originalRequest)) {
                copyCachedBodyToResponse(responseWrapper);
            }
            return;
        }

        captureRequestBody(context, originalRequest, requestMaxLen);

        String responseBodyType = originalResponse.getContentType();
        context.setResponseBodyType(responseBodyType);
        if (responseWrapper != null) {
            if (HttpContentTypeUtil.isTextBody(responseBodyType) || VerifyUtil.isEmpty(responseBodyType)) {
                InputStream contentInputStream = responseWrapper.getContentInputStream();
                try {
                    byte[] bytes = IoUtil.toLimitByteArray(contentInputStream, responseMaxLength);
                    context.setResponseBody(new String(bytes, 0, bytes.length, StandardCharsets.UTF_8));
                } finally {
                    IoUtil.close(contentInputStream);
                }
            }
            copyCachedBodyToResponse(responseWrapper);
        } else {
            log.error("can't get response body.");
        }
        accessLogHelper.asyncRecordAccessLog(context);
    }

    /**
     * Stores at most {@code requestMaxLen} bytes of the cached request body in the context,
     * decoded with the request's character encoding. Only textual or undeclared content types
     * are captured.
     *
     * @param context         current request/response context
     * @param originalRequest request held by the context
     * @param requestMaxLen   maximum number of request body bytes to keep
     */
    private void captureRequestBody(RequestResponseContext context, HttpServletRequest originalRequest, int requestMaxLen) {
        ContentCachingRequestWrapper requestWrapper =
                WebUtils.getNativeRequest(originalRequest, ContentCachingRequestWrapper.class);
        if (requestWrapper == null) {
            log.error("can't get request body.");
            return;
        }
        String contentType = context.getRequestContentType();
        if (!HttpContentTypeUtil.isTextBody(contentType) && !VerifyUtil.isEmpty(contentType)) {
            return;
        }
        byte[] buf = requestWrapper.getContentAsByteArray();
        if (buf.length == 0) {
            return;
        }
        int length = Math.min(buf.length, requestMaxLen);
        try {
            context.setRequestBody(new String(buf, 0, length, requestWrapper.getCharacterEncoding()));
        } catch (UnsupportedEncodingException ex) {
            context.setRequestBody("[unknown]");
            log.error("get request body error.", ex);
        }
    }

    /**
     * Copies the body cached by the wrapper to the underlying response.
     *
     * @param responseWrapper wrapper holding the cached response body
     * @throws CheckedException if copying fails
     */
    private void copyCachedBodyToResponse(ContentCachingResponseWrapper responseWrapper) {
        try {
            responseWrapper.copyBodyToResponse();
        } catch (IOException e) {
            log.error("copy response body to original response error.", e);
            throw new CheckedException(SystemErrorCode.API_SERVER_ERROR);
        }
    }

    /**
     * Tells whether a request to the given URI/method is eligible for access logging.
     *
     * @param uri    request URI
     * @param method HTTP method of the request
     * @return {@code false} for POSTs to the log-recording endpoints themselves, for methods
     *         listed as ignored and for URIs containing (case-insensitively) one of the
     *         configured ignore fragments; {@code true} otherwise
     */
    private boolean isRequireRecordLogUri(String uri, String method) {
        // Requests that record the trace/operation log themselves must not be recorded again.
        boolean isUriMatch = AccessLogFacadeConstant.TRACE_LOG_RECORD_URI.equalsIgnoreCase(uri)
                || AccessLogFacadeConstant.OPT_LOG_RECORD_URI.equalsIgnoreCase(uri);
        if (isUriMatch && HttpMethod.POST.matches(method)) {
            return false;
        }

        // Honour the configured ignore lists.
        RequestConfigProperties requestConfigProperties = frameworkBaseConfigProperties.getRequest();
        List<String> ignoreRecordRequestMethods = requestConfigProperties.getIgnoreRecordRequestMethods();
        if (ignoreRecordRequestMethods != null && ignoreRecordRequestMethods.contains(method)) {
            return false;
        }
        List<String> ignoreRecordTraceUris = requestConfigProperties.getIgnoreRecordTraceUris();
        if (ignoreRecordTraceUris == null || ignoreRecordTraceUris.isEmpty()) {
            return true;
        }
        // Locale.ROOT keeps the comparison independent of the JVM default locale; the URI is
        // lower-cased once instead of once per configured entry.
        String lowerCaseUri = uri.toLowerCase(Locale.ROOT);
        for (String ignoreRecordUri : ignoreRecordTraceUris) {
            if (lowerCaseUri.contains(ignoreRecordUri.toLowerCase(Locale.ROOT))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Pre-processing. Security control has already been done at the gateway, so no distinction
     * between internal and external access is made here.
     * <p>
     * Copies method, URI, headers, content type and remote address into the context, publishes
     * the trace id and previous application name (MDC, request attributes, response header) and
     * initialises the user context and system version from their headers. Malformed user
     * context or system version headers are logged and ignored.
     *
     * @param requestResponseContext request and response context
     */
    private void preprocess(RequestResponseContext requestResponseContext) {
        HttpServletRequest request = requestResponseContext.getOriginalRequest();
        String method = request.getMethod();
        String uri = request.getRequestURI();
        Map<String, List<String>> allHeaders = HttpServletUtil.getAllHeaders(request);
        RequestConfigProperties requestConfigProperties = requestResponseContext.getRequestConfigProperties();
        requestResponseContext.setOriginalHeaders(allHeaders);
        requestResponseContext.setMethod(method);
        requestResponseContext.setUri(uri);
        requestResponseContext.setRequestContentType(request.getContentType());

        // Remote address: the configured header wins, the socket address is the fallback.
        List<String> remoteHostAddr = Optional.ofNullable(allHeaders.get(requestConfigProperties.getRemoteHostHeaderKey()))
                .orElse(Collections.singletonList(request.getRemoteAddr()));
        requestResponseContext.setRemoteHostAddr(remoteHostAddr);

        String traceId = MDC.get(FrameworkCommonConstant.GLOBAL_LOG_TRACE_KEY);
        if (VerifyUtil.isEmpty(traceId)) {
            traceId = RandomUtil.random20LenId();
            MDC.put(FrameworkCommonConstant.GLOBAL_LOG_TRACE_KEY, traceId);
        }
        request.setAttribute(FrameworkCommonConstant.GLOBAL_LOG_TRACE_KEY, traceId);
        requestResponseContext.getOriginalResponse().setHeader(FrameworkCommonConstant.GLOBAL_LOG_TRACE_KEY, traceId);

        String previousAppName = CollectionUtil.getFirstElementQuietly(allHeaders.get(FrameworkCoreConstant.GLOBAL_PREVIOUS_APP_KEY),
                FrameworkCommonConstant.EMPTY_STRING);
        MDC.put(FrameworkCoreConstant.GLOBAL_PREVIOUS_APP_KEY, previousAppName);
        request.setAttribute(FrameworkCoreConstant.GLOBAL_PREVIOUS_APP_KEY, previousAppName);

        // User context.
        String userContext = CollectionUtil.getFirstElementQuietly(allHeaders.get(FrameworkCoreConstant.USER_CONTEXT_HEADER_KEY));
        if (!VerifyUtil.isEmpty(userContext)) {
            String decodedUserContext = null;
            try {
                decodedUserContext = URLDecoder.decode(userContext, StandardCharsets.UTF_8.name());
            } catch (UnsupportedEncodingException | IllegalArgumentException e) {
                // URLDecoder rejects malformed percent-escapes with IllegalArgumentException;
                // a bad header must not abort the request.
                log.warn("ignore malformed user context header: {}", e.getMessage());
            }
            if (decodedUserContext != null) {
                SecurityContextHolder.setUserContext(decodedUserContext);
            }
        }

        // System version.
        String systemVersion = CollectionUtil.getFirstElementQuietly(allHeaders.get(FrameworkCommonConstant.SYSTEM_VERSION_KEY));
        if (!VerifyUtil.isEmpty(systemVersion)) {
            try {
                SystemVersionContextHolder.setCurrentSystemVersion(Long.parseLong(systemVersion));
            } catch (NumberFormatException e) {
                // The header is client supplied; a non-numeric value must not abort the request.
                log.warn("ignore malformed system version header: {}", systemVersion);
            }
        }
    }
}